{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z7gEg4DRBwbO"
      },
      "source": [
        "\n",
        "# Overview\n",
        "This CodeLab demonstrates how to build an LSTM model for MNIST recognition in Keras and convert it to TensorFlow Lite, where the LSTM is emitted as a fused TFLite op.\n",
        "\n",
        "This CodeLab is very similar to the Keras LSTM [CodeLab](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/examples/experimental_new_converter/keras_lstm.ipynb). However, here we create fused LSTM ops rather than the unfused version.\n",
        "\n",
        "Also note: the goal is not to build a real-world application, only to demonstrate how to use TensorFlow Lite. A CNN would give a much better model. For a more canonical LSTM example, see the [Keras IMDB LSTM example](https://github.com/keras-team/keras/blob/master/examples/imdb_lstm.py)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1v9muouWCrLA"
      },
      "source": [
        "\n",
        "# Step 0: Prerequisites\n",
        "It's recommended to try this feature with the newest TensorFlow nightly pip build."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P0bSI6A5AWaT"
      },
      "outputs": [],
      "source": [
        "!pip install tf-nightly"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zeo4IA1xC4O8"
      },
      "source": [
        "# Step 1: Build the MNIST LSTM model."
      ]
    },
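    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The model in this step treats each 28x28 MNIST image as a sequence: with `time_major=False`, each of the 28 pixel rows is one time step carrying 28 features. That is why the `Input` shape is `(28, 28)` and why Step 3 fixes `STEPS = 28` and `INPUT_SIZE = 28`. The stdlib-only sketch that follows illustrates this layout on a hypothetical synthetic image, with plain lists standing in for the NumPy arrays used later."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical synthetic 28x28 image built from plain lists (no NumPy);\n",
        "# pixel values are integers in [0, 255], like the raw MNIST data.\n",
        "ROWS, COLS = 28, 28\n",
        "image = [[(r * COLS + c) % 256 for c in range(COLS)] for r in range(ROWS)]\n",
        "\n",
        "# Scale to [0, 1], mirroring x_train / 255.0 in Step 2.\n",
        "scaled = [[p / 255.0 for p in row] for row in image]\n",
        "\n",
        "# With time_major=False the LSTM reads one row per time step:\n",
        "# time step t is row t, a vector of 28 features.\n",
        "time_steps = len(scaled)\n",
        "features_per_step = len(scaled[0])\n",
        "\n",
        "# Adding a batch dimension of 1 gives the [BATCH_SIZE, STEPS, INPUT_SIZE]\n",
        "# layout used for the TensorSpec in Step 3.\n",
        "batch = [scaled]\n",
        "print(len(batch), time_steps, features_per_step)"
      ]
    },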
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yMtp56hRBvHe"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IQtjKcMYC_nD"
      },
      "outputs": [],
      "source": [
        "model = tf.keras.models.Sequential([\n",
        "    tf.keras.layers.Input(shape=(28, 28), name='input'),\n",
        "    tf.keras.layers.LSTM(20, time_major=False, return_sequences=True),\n",
        "    tf.keras.layers.Flatten(),\n",
        "    tf.keras.layers.Dense(10, activation=tf.nn.softmax, name='output')\n",
        "])\n",
        "model.compile(optimizer='adam',\n",
        "              loss='sparse_categorical_crossentropy',\n",
        "              metrics=['accuracy'])\n",
        "model.summary()"
      ]
    },
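    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sanity check on `model.summary()`, the parameter counts can be worked out by hand; the totals below should match what the summary reports. The sketch assumes the standard Keras LSTM layout (four gates, each with an input kernel, a recurrent kernel and a bias) and uses hypothetical helper names. Because `return_sequences=True`, `Flatten` receives all 28 time steps of 20 units each."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hand-computed parameter counts for the model above (hypothetical helpers).\n",
        "TIME_STEPS, FEATURES, UNITS, CLASSES = 28, 28, 20, 10\n",
        "\n",
        "def lstm_params(input_dim, units):\n",
        "  # 4 gates x (input kernel + recurrent kernel + bias).\n",
        "  return 4 * (input_dim * units + units * units + units)\n",
        "\n",
        "def dense_params(input_dim, output_dim):\n",
        "  # Weight matrix plus one bias per output unit.\n",
        "  return input_dim * output_dim + output_dim\n",
        "\n",
        "lstm = lstm_params(FEATURES, UNITS)     # 3920\n",
        "flat = TIME_STEPS * UNITS               # (28, 20) flattened to 560 features\n",
        "dense = dense_params(flat, CLASSES)     # 5610\n",
        "total = lstm + dense                    # 9530 (Flatten has no parameters)\n",
        "print(lstm, dense, total)"
      ]
    },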
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qS_79wHVDcri"
      },
      "source": [
        "# Step 2: Train \u0026 Evaluate the model.\n",
        "We will train the model using MNIST data."
      ]
    },
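    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Step 1 compiled the model with `sparse_categorical_crossentropy`, which pairs integer labels (such as `y_train`) with the probabilities produced by the softmax `output` layer. For a single example the loss is the negative log of the probability assigned to the true class, and Keras averages this over the batch. The pure-Python sketch below illustrates the per-example computation on a hypothetical 3-class example rather than the model's 10 classes; `softmax` and `sparse_ce_loss` are illustrative helpers, not Keras APIs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def softmax(logits):\n",
        "  # Subtract the max logit for numerical stability before exponentiating.\n",
        "  m = max(logits)\n",
        "  exps = [math.exp(z - m) for z in logits]\n",
        "  total = sum(exps)\n",
        "  return [e / total for e in exps]\n",
        "\n",
        "def sparse_ce_loss(label, probs):\n",
        "  # label is an integer class index; probs is a probability vector.\n",
        "  return -math.log(probs[label])\n",
        "\n",
        "# Hypothetical logits for one example whose true label is class 0.\n",
        "probs = softmax([2.0, 1.0, 0.1])\n",
        "loss = sparse_ce_loss(0, probs)\n",
        "print([round(p, 3) for p in probs], round(loss, 3))"
      ]
    },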
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-yEAraXGDlcQ"
      },
      "outputs": [],
      "source": [
        "# Load MNIST dataset.\n",
        "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
        "x_train, x_test = x_train / 255.0, x_test / 255.0\n",
        "x_train = x_train.astype(np.float32)\n",
        "x_test = x_test.astype(np.float32)\n",
        "\n",
        "# Change this to True if you want to test the flow rapidly.\n",
        "# Train with a small dataset and only 1 epoch. The model will work poorly\n",
        "# but this provides a fast way to test if the conversion works end to end.\n",
        "_FAST_TRAINING = False\n",
        "_EPOCHS = 5\n",
        "if _FAST_TRAINING:\n",
        "  _EPOCHS = 1\n",
        "  _TRAINING_DATA_COUNT = 1000\n",
        "  x_train = x_train[:_TRAINING_DATA_COUNT]\n",
        "  y_train = y_train[:_TRAINING_DATA_COUNT]\n",
        "\n",
        "model.fit(x_train, y_train, epochs=_EPOCHS)\n",
        "model.evaluate(x_test, y_test, verbose=0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5pGyWlkJDpMQ"
      },
      "source": [
        "# Step 3: Convert the Keras model to a TensorFlow Lite model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tB1NZBUHDogR"
      },
      "outputs": [],
      "source": [
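        "# Wrap the model in a tf.function so that a concrete function with a\n",
        "# fixed input signature can be exported.\n",
        "run_model = tf.function(lambda x: model(x))\n",
        "# Fix the input shape to [batch, time steps, features]. This is important\n",
        "# for the conversion.\n",
        "BATCH_SIZE = 1\n",
        "STEPS = 28\n",
        "INPUT_SIZE = 28\n",
        "concrete_func = run_model.get_concrete_function(\n",
        "    tf.TensorSpec([BATCH_SIZE, STEPS, INPUT_SIZE], model.inputs[0].dtype))\n",
        "\n",
        "# Save the model as a SavedModel, exporting the fixed-shape concrete\n",
        "# function as its serving signature.\n",
        "MODEL_DIR = \"keras_lstm\"\n",
        "model.save(MODEL_DIR, save_format=\"tf\", signatures=concrete_func)\n",
        "\n",
        "# Convert the SavedModel to a TensorFlow Lite flatbuffer.\n",
        "converter = tf.lite.TFLiteConverter.from_saved_model(MODEL_DIR)\n",
        "tflite_model = converter.convert()"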
      ]
    },
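    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Conceptually, the fused TFLite LSTM op evaluates the whole LSTM recurrence (input, forget and output gates, the candidate cell value, and the state update) inside one kernel, while the unfused conversion expresses the same math as many separate ops. As an illustration only, the sketch below writes out one time step for a single hidden unit in plain Python, assuming the standard LSTM equations with sigmoid gates and tanh activations (the defaults of the Keras `LSTM` layer) and no peepholes or projection. It is not the TFLite kernel, and the weights are hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def sigmoid(z):\n",
        "  return 1.0 / (1.0 + math.exp(-z))\n",
        "\n",
        "def lstm_step(x, h, c, w, u, b):\n",
        "  # One LSTM time step for a single hidden unit.\n",
        "  # x: input features (list); h, c: previous hidden and cell state.\n",
        "  # w[gate]: input weights, u[gate]: recurrent weight, b[gate]: bias,\n",
        "  # for gate in 'i' (input), 'f' (forget), 'g' (candidate), 'o' (output).\n",
        "  def pre(gate):\n",
        "    return sum(wk * xk for wk, xk in zip(w[gate], x)) + u[gate] * h + b[gate]\n",
        "  i = sigmoid(pre('i'))         # input gate\n",
        "  f = sigmoid(pre('f'))         # forget gate\n",
        "  g = math.tanh(pre('g'))       # candidate cell value\n",
        "  o = sigmoid(pre('o'))         # output gate\n",
        "  c_new = f * c + i * g         # keep part of the old cell, add new content\n",
        "  h_new = o * math.tanh(c_new)  # hidden state, also the step's output\n",
        "  return h_new, c_new\n",
        "\n",
        "# Toy sequence: 3 time steps, 2 features each, hypothetical shared weights.\n",
        "w = {gate: [0.5, -0.25] for gate in 'ifgo'}\n",
        "u = {gate: 0.1 for gate in 'ifgo'}\n",
        "b = {gate: 0.0 for gate in 'ifgo'}\n",
        "h, c = 0.0, 0.0\n",
        "for x in [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]:\n",
        "  h, c = lstm_step(x, h, c, w, u, b)\n",
        "print(h, c)"
      ]
    },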
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "INFyl-J3FAOY"
      },
      "source": [
        "# Step 4: Check the converted TensorFlow Lite model.\n",
        "Now load the TensorFlow Lite model and use the TensorFlow Lite Python interpreter to verify the results."
      ]
    },
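    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The verification loop in this step calls `interpreter.reset_all_variables()` after every test case because, as its comments note, the TFLite fused LSTM kernel is stateful. The toy class below is entirely hypothetical (it is not a TFLite API); it only illustrates why state carried over between invocations must be cleared for results to be reproducible."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "class ToyStatefulKernel:\n",
        "  # Hypothetical stand-in for a stateful op: it keeps a running state\n",
        "  # between invoke() calls, the way a stateful recurrent kernel can.\n",
        "  def __init__(self):\n",
        "    self.state = 0.0\n",
        "\n",
        "  def invoke(self, x):\n",
        "    self.state = 0.5 * self.state + x\n",
        "    return self.state\n",
        "\n",
        "  def reset_all_variables(self):\n",
        "    # Clear the carried-over state, analogous to the interpreter call.\n",
        "    self.state = 0.0\n",
        "\n",
        "kernel = ToyStatefulKernel()\n",
        "first = kernel.invoke(1.0)\n",
        "second = kernel.invoke(1.0)   # same input, different output: state leaked\n",
        "kernel.reset_all_variables()\n",
        "third = kernel.invoke(1.0)    # matches the first call again\n",
        "print(first, second, third)"
      ]
    },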
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0-b0IKK2FGuO"
      },
      "outputs": [],
      "source": [
        "# Number of test images to compare between TensorFlow and TensorFlow Lite.\n",
        "TEST_CASES = 10\n",
        "\n",
        "# Load the converted model into the TensorFlow Lite interpreter.\n",
        "interpreter = tf.lite.Interpreter(model_content=tflite_model)\n",
        "interpreter.allocate_tensors()\n",
        "input_details = interpreter.get_input_details()\n",
        "output_details = interpreter.get_output_details()\n",
        "\n",
        "for i in range(TEST_CASES):\n",
        "  # Run the model with TensorFlow to get the expected result.\n",
        "  expected = model.predict(x_test[i:i+1])\n",
        "  # Run the same input through TensorFlow Lite.\n",
        "  interpreter.set_tensor(input_details[0][\"index\"], x_test[i:i+1, :, :])\n",
        "  interpreter.invoke()\n",
        "  result = interpreter.get_tensor(output_details[0][\"index\"])\n",
        "\n",
        "  # Assert that the TFLite result is consistent with the TF result.\n",
        "  np.testing.assert_almost_equal(expected, result, decimal=5)\n",
        "  print(\"Done. The result of TensorFlow matches the result of TensorFlow Lite.\")\n",
        "\n",
        "  # The TFLite fused LSTM kernel is stateful, so reset its internal states\n",
        "  # before the next test case.\n",
        "  interpreter.reset_all_variables()"
      ]
    },
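    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, the NumPy documentation describes `np.testing.assert_almost_equal(actual, desired, decimal=5)` as checking that the shapes match and that `abs(desired - actual) < 1.5 * 10**(-decimal)` element by element, so the comparison in this step tolerates differences smaller than 1.5e-5. A stdlib-only sketch of that criterion on flat lists follows; the `almost_equal` helper is hypothetical and skips NumPy's handling of NaNs and nested shapes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def almost_equal(expected, actual, decimal=5):\n",
        "  # Mirrors the documented NumPy criterion for flat lists:\n",
        "  # matching lengths and abs(desired - actual) < 1.5 * 10**(-decimal).\n",
        "  if len(expected) != len(actual):\n",
        "    return False\n",
        "  tol = 1.5 * 10 ** (-decimal)\n",
        "  return all(abs(e - a) < tol for e, a in zip(expected, actual))\n",
        "\n",
        "# Hypothetical softmax outputs from TensorFlow and TensorFlow Lite.\n",
        "tf_probs = [0.1, 0.7, 0.2]\n",
        "close = almost_equal(tf_probs, [0.100001, 0.699999, 0.2])  # diffs of about 1e-6\n",
        "far = almost_equal(tf_probs, [0.1001, 0.6999, 0.2])        # diffs of about 1e-4\n",
        "print(close, far)"
      ]
    },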
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Cf6KC9fbFY5f"
      },
      "source": [
        "# Step 5: Let's inspect the converted TFLite model.\n",
        "\n",
        "Inspecting the converted model shows that the LSTM is in its fused format.\n",
        "\n",
        "![Fused LSTM](https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/lite/examples/experimental_new_converter/keras_lstm.png)\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "Keras LSTM fusion Codelab.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
